package com.mlamp.字典树;

import java.util.*;

/**
 * A trie (prefix tree) that stores words made of the lowercase ASCII letters {@code 'a'..'z'}.
 *
 * <p>Each distinct stored word is kept once in {@code keyword}. The node where a word ends
 * records that word's index in the list, so prefix queries can return the original strings.
 * Every node counts how many distinct stored words pass through it, which lets
 * {@link #countPrefix(String)} answer by walking only the prefix.
 *
 * <p>The "A"/"B" method variants delegate to a single implementation so their behavior cannot
 * drift apart.
 */
public class 字典树 {

    /** Small demo: inserts a few words, prints those starting with "hell", then checks "helap". */
    public static void main(String[] args) {
        字典树 tree = new 字典树();
        List<String> words = Arrays.asList("hello", "helloword", "hellooo", "helap");
        for (String item : words) {
            tree.insert(item);
        }
        // queryAllByPrefix never returns null, so no null check is needed.
        for (String item : tree.queryAllByPrefix("hell")) {
            System.out.println(item);
        }

        System.out.println(tree.exist("helap"));
    }

    /** Number of child slots per node: one per letter 'a'..'z'. */
    private static final int ALPHABET_SIZE = 26;

    /** Sentinel root; its {@code num} is the total number of distinct stored words. */
    private final Node root;

    /** Stored words in insertion order; a terminal node's {@code location} indexes into this list. */
    private final List<String> keyword = new ArrayList<String>();

    /** A trie node; the edge leading into it is labelled {@link #ch}. */
    private static final class Node {
        /** True when the path from the root to this node spells a stored word. */
        boolean isString;
        /** Number of distinct stored words whose path passes through this node. */
        int num;
        /** Character on the edge into this node (' ' for the root). */
        char ch;
        /** Index into {@code keyword} of the word ending here, or -1 if no word ends here. */
        int location;
        /** Children indexed by {@code c - 'a'}. */
        Node[] child;

        public Node() {
            this(' ');
        }

        public Node(char ch) {
            child = new Node[ALPHABET_SIZE];
            isString = false;
            // A freshly created non-root node is on the path of the word that created it.
            num = 1;
            this.ch = ch;
            location = -1;
        }

        @Override
        public String toString() {
            return "Node{" +
                    "isString=" + isString +
                    ", num=" + num +
                    ", ch=" + ch +
                    ", child=" + Arrays.toString(child) +
                    '}';
        }
    }

    /** Creates an empty trie. */
    public 字典树() {
        root = new Node();
        // The root lies on every word's path; start at 0 so countPrefix("") reports the word count.
        root.num = 0;
    }

    /**
     * Maps a character to its child slot.
     *
     * @return {@code c - 'a'} for {@code 'a'..'z'}, otherwise -1
     */
    private static int slotOf(char c) {
        return (c >= 'a' && c < 'a' + ALPHABET_SIZE) ? c - 'a' : -1;
    }

    /**
     * Follows {@code path} from the root.
     *
     * @return the node reached, or {@code null} if {@code path} is null, contains a character
     *     outside {@code 'a'..'z'}, or leaves the trie
     */
    private Node findNode(String path) {
        if (path == null) {
            return null;
        }
        Node node = root;
        for (int i = 0; i < path.length(); i++) {
            int slot = slotOf(path.charAt(i));
            if (slot < 0 || node.child[slot] == null) {
                return null;
            }
            node = node.child[slot];
        }
        return node;
    }

    /**
     * Stores a non-empty word. Re-inserting a word that is already stored is a no-op.
     *
     * @throws IllegalArgumentException if {@code str} contains a character outside
     *     {@code 'a'..'z'}; the trie is left unchanged in that case
     */
    private void addWord(String str) {
        // Validate before mutating anything so a bad character cannot leave a half-inserted word.
        for (int i = 0; i < str.length(); i++) {
            if (slotOf(str.charAt(i)) < 0) {
                throw new IllegalArgumentException("unsupported character '" + str.charAt(i)
                        + "' at index " + i + " in \"" + str + "\"; only 'a'-'z' are allowed");
            }
        }
        Node existing = findNode(str);
        if (existing != null && existing.isString) {
            // Already stored: keep keyword list and per-node counts free of duplicates.
            return;
        }
        int location = keyword.size();
        keyword.add(str);
        Node node = root;
        node.num++;
        for (int i = 0; i < str.length(); i++) {
            int slot = slotOf(str.charAt(i));
            if (node.child[slot] == null) {
                node.child[slot] = new Node(str.charAt(i));
            } else {
                node.child[slot].num++;
            }
            node = node.child[slot];
        }
        node.isString = true;
        node.location = location;
    }

    /** Same as {@link #insert(String)}: null or empty input is silently ignored. */
    public void insertA(String str) {
        insert(str);
    }

    /**
     * Like {@link #insert(String)}, but rejects null or empty input.
     *
     * @throws IllegalArgumentException if {@code str} is null, empty, or contains a character
     *     outside {@code 'a'..'z'}
     */
    public void insertB(String str) {
        if (str == null || str.isEmpty()) throw new IllegalArgumentException("invalid input");
        addWord(str);
    }

    /**
     * Inserts {@code str}. Null or empty input is ignored, and duplicates are stored only once.
     *
     * @throws IllegalArgumentException if {@code str} contains a character outside {@code 'a'..'z'}
     */
    public void insert(String str) {
        if (str == null || str.isEmpty()) {
            return;
        }
        addWord(str);
    }

    /**
     * @return true if {@code word} was inserted as a whole word. Returns false for null, for
     *     strings that are only prefixes of stored words, and for strings with characters
     *     outside {@code 'a'..'z'}.
     */
    public boolean exist(String word) {
        Node node = findNode(word);
        return node != null && node.isString;
    }

    /** Same as {@link #exist(String)}. */
    public boolean existA(String word) {
        return exist(word);
    }

    /**
     * Returns every stored word that starts with {@code prefix}, in lexicographic order.
     * An empty prefix returns all words.
     *
     * @return the matching words; an empty collection if nothing matches or {@code prefix} is
     *     null or contains characters outside {@code 'a'..'z'}
     */
    public Collection<String> queryAllByPrefix(String prefix) {
        Node start = findNode(prefix);
        if (start == null) {
            return Collections.emptyList();
        }
        // LinkedHashSet keeps the pre-order walk order, which visits children 'a'..'z' and so
        // yields the words lexicographically.
        Set<String> result = new LinkedHashSet<>();
        preWalkPlus(start, result);
        return result;
    }

    /** Same as {@link #queryAllByPrefix(String)}. */
    public Collection<String> queryAllByPrefixB(String prefix) {
        return queryAllByPrefix(prefix);
    }

    /**
     * Pre-order walk from {@code node}: adds the word ending at each terminal node to
     * {@code result}, visiting children in slot order 'a'..'z'.
     */
    public void preWalkPlus(Node node, Set<String> result) {
        if (node.isString) {
            result.add(keyword.get(node.location));
        }
        for (int index = 0; index < ALPHABET_SIZE; index++) {
            if (node.child[index] != null) {
                preWalkPlus(node.child[index], result);
            }
        }
    }

    /**
     * @return how many distinct stored words start with {@code prefix}. An empty prefix counts
     *     all words. Returns 0 for null or for a prefix with characters outside {@code 'a'..'z'}.
     */
    public int countPrefix(String prefix) {
        Node node = findNode(prefix);
        return node == null ? 0 : node.num;
    }

    /** Debug helper: prints each edge letter below {@code node}, pre-order, one per line. */
    public void preWalk(Node node) {
        for (int index = 0; index < ALPHABET_SIZE; index++) {
            if (node.child[index] != null) {
                System.out.println((char) ('a' + index) + "--");
                preWalk(node.child[index]);
            }
        }
    }

    /** Debug helper: like {@link #preWalk(Node)} but prints all edges on one line. */
    public void preWalk2(Node node) {
        for (int index = 0; index < ALPHABET_SIZE; index++) {
            if (node.child[index] != null) {
                System.out.print((char) ('a' + index) + "--");
                preWalk2(node.child[index]);
            }
        }
    }


}
